using System;
using System.Collections.Generic;
using UnityEngine.Perception.Randomization.Samplers;

namespace UnityEngine.Perception.Randomization.Parameters
{
    /// <summary>
    /// Generates samples by choosing one option from a list of choices
    /// </summary>
    /// <remarks>
    /// Every category is paired with a non-negative probability weight. When <c>uniform</c> is true the weights
    /// are ignored and every category is equally likely; otherwise categories are chosen in proportion to their
    /// weights (the weights do not need to sum to 1, they are normalized internally).
    /// </remarks>
    /// <typeparam name="T">The sample type of the categorical parameter</typeparam>
    [Serializable]
    public class CategoricalParameter<T> : CategoricalParameterBase
    {
        // When true, Sample() picks every category with equal likelihood and ignores the stored weights.
        [SerializeField] internal bool uniform = true;

        // The selectable options. Entry i pairs with the inherited probabilities[i] weight.
        [SerializeField] List<T> m_Categories = new List<T>();

        // Source of the random value (range 0..1) used to select a category.
        UniformSampler m_Sampler = new UniformSampler(0f, 1f);

        // Cumulative normalized weights built by NormalizeProbabilities(). This field is not marked
        // [SerializeField], so it can be null until it is rebuilt (see Sample()).
        float[] m_NormalizedProbabilities;

        /// <summary>
        /// Returns an IEnumerable that iterates over each sampler field in this parameter
        /// </summary>
        public override IEnumerable<ISampler> samplers
        {
            get { yield return m_Sampler; }
        }

        /// <summary>
        /// The sample type generated by this parameter
        /// </summary>
        public sealed override Type sampleType => typeof(T);

        /// <summary>
        /// Returns the number of stored categories
        /// </summary>
        /// <value>The number of stored categories</value>
        public int Count => m_Categories.Count;

        /// <summary>
        /// Returns the number of stored categories
        /// </summary>
        /// <value>The number of stored categories</value>
        /// <returns>Count of categories</returns>
        [Obsolete("GetCategoryCount method has been deprecated. Please use Count (UnityUpgradable)")]
        public int GetCategoryCount() => m_Categories.Count;

        /// <summary>
        /// Returns the category stored at the specified index
        /// </summary>
        /// <param name="index">The index of the category to lookup</param>
        /// <returns>The category stored at the specified index</returns>
        public T GetCategory(int index) => m_Categories[index];

        /// <summary>
        /// Returns the probability weight stored at the specified index
        /// </summary>
        /// <remarks>
        /// This is the raw weight that was supplied for the category, not the normalized probability.
        /// </remarks>
        /// <param name="index">The index of the probability value to lookup</param>
        /// <returns>The probability value stored at the specified index</returns>
        public float GetProbability(int index) => probabilities[index];

        /// <summary>
        /// Replaces this parameter's categorical options with equally weighted options and switches the
        /// parameter to uniform sampling.
        /// </summary>
        /// <param name="categoricalOptions">The categorical options to configure</param>
        /// <exception cref="ParameterValidationException">Thrown when <paramref name="categoricalOptions"/> is empty.</exception>
        public void SetOptions(IEnumerable<T> categoricalOptions)
        {
            m_Categories.Clear();
            probabilities.Clear();
            foreach (var category in categoricalOptions)
                AddOption(category, 1f);

            // All weights are equal, so uniform sampling yields the intended distribution.
            uniform = true;
            NormalizeProbabilities();
        }

        /// <summary>
        /// Replaces this parameter's categorical options with weighted options and switches the parameter
        /// to weighted (non-uniform) sampling so that the supplied weights take effect.
        /// </summary>
        /// <param name="categoricalOptions">The categorical options to configure, each paired with a non-negative weight</param>
        /// <exception cref="ParameterValidationException">
        /// Thrown when a weight is negative or when the weights do not sum to a positive value.
        /// </exception>
        public void SetOptions(IEnumerable<(T, float)> categoricalOptions)
        {
            m_Categories.Clear();
            probabilities.Clear();
            foreach (var (category, probability) in categoricalOptions)
                AddOption(category, probability);

            // Without this, the default uniform == true would make Sample() ignore the weights just supplied.
            uniform = false;
            NormalizeProbabilities();
        }

        // Appends one category together with its raw weight, keeping both lists index-aligned.
        void AddOption(T option, float probability)
        {
            m_Categories.Add(option);
            probabilities.Add(probability);
        }

        /// <summary>
        /// Returns a list of the potential categories this parameter can generate
        /// </summary>
        /// <remarks>
        /// Each entry pairs a category with its raw (unnormalized) weight. A new list is built on every access.
        /// </remarks>
        public IReadOnlyList<(T, float)> categories
        {
            get
            {
                var catOptions = new List<(T, float)>(m_Categories.Count);
                for (var i = 0; i < m_Categories.Count; i++)
                    catOptions.Add((m_Categories[i], probabilities[i]));
                return catOptions;
            }
        }

        /// <summary>
        /// Validates the categorical probabilities assigned to this parameter
        /// </summary>
        /// <exception cref="ParameterValidationException">
        /// Thrown when no categories exist, when a category appears more than once, or (in non-uniform mode)
        /// when the weights do not match the categories or cannot be normalized.
        /// </exception>
        public override void Validate()
        {
            base.Validate();

            // Check for a non-zero amount of specified categories
            if (m_Categories.Count == 0)
                throw new ParameterValidationException("No options added to categorical parameter");

            // Check for duplicate categories. HashSet.Add returns false when the item is already present,
            // so a single call both tests for and records each option.
            var uniqueCategories = new HashSet<T>();
            foreach (var option in m_Categories)
            {
                if (!uniqueCategories.Add(option))
                    throw new ParameterValidationException($"Duplicate categories in {typeof(T)}: {option}");
            }

            // Weights only matter in non-uniform mode: they must line up with the categories and be normalizable.
            if (!uniform)
            {
                if (probabilities.Count != m_Categories.Count)
                    throw new ParameterValidationException("Number of options must be equal to the number of probabilities");
                NormalizeProbabilities();
            }
        }

        /// <summary>
        /// Generates a sample
        /// </summary>
        /// <returns>The generated sample</returns>
        /// <exception cref="ParameterValidationException">
        /// Thrown in non-uniform mode when the stored weights cannot be normalized.
        /// </exception>
        public T Sample()
        {
            var randomValue = m_Sampler.Sample();
            if (uniform)
            {
                // Scale the random value into a category index, clamping the case where the sampler
                // returns exactly the upper bound.
                var index = Math.Min((int)(randomValue * m_Categories.Count), m_Categories.Count - 1);
                return m_Categories[index];
            }

            // The cumulative table is not serialized. It may be missing (for example before Validate has
            // run) or sized for a different set of weights, so rebuild it before using it.
            if (m_NormalizedProbabilities == null || m_NormalizedProbabilities.Length != probabilities.Count)
                NormalizeProbabilities();

            return m_Categories[BinarySearch(m_NormalizedProbabilities, randomValue)];
        }

        /// <summary>
        /// Generates a generic sample
        /// </summary>
        /// <returns>The generated sample</returns>
        public override object GenericSample()
        {
            return Sample();
        }

        // Rebuilds m_NormalizedProbabilities as the running sum of each weight divided by the total weight.
        // Throws ParameterValidationException for a negative weight or a non-positive total.
        void NormalizeProbabilities()
        {
            var totalProbability = 0f;
            for (var i = 0; i < probabilities.Count; i++)
            {
                var probability = probabilities[i];
                if (probability < 0f)
                    throw new ParameterValidationException($"Found negative probability at index {i}");
                totalProbability += probability;
            }

            if (totalProbability <= 0f)
                throw new ParameterValidationException("Total probability must be greater than 0");

            var sum = 0f;
            m_NormalizedProbabilities = new float[probabilities.Count];
            for (var i = 0; i < probabilities.Count; i++)
            {
                sum += probabilities[i] / totalProbability;
                m_NormalizedProbabilities[i] = sum;
            }
        }

        /// <summary>
        /// Locates the category index for <paramref name="key"/> in a cumulative probability table.
        /// </summary>
        /// <param name="normalizedProbabilities">Ascending cumulative probabilities.</param>
        /// <param name="key">The random value to look up.</param>
        /// <returns>
        /// An index whose value equals <paramref name="key"/> if one exists. Otherwise, the first index whose
        /// value is greater than <paramref name="key"/>. If <paramref name="key"/> exceeds every entry (for
        /// example because float rounding left the final cumulative value slightly below 1), the last index.
        /// Returns -1 for an empty array.
        /// </returns>
        internal static int BinarySearch(float[] normalizedProbabilities, float key)
        {
            var minNum = 0;
            var maxNum = normalizedProbabilities.Length - 1;

            while (minNum <= maxNum)
            {
                var mid = minNum + (maxNum - minNum) / 2;
                // ReSharper disable once CompareOfFloatsByEqualityOperator
                if (key == normalizedProbabilities[mid])
                {
                    return mid;
                }
                if (key < normalizedProbabilities[mid])
                {
                    maxNum = mid - 1;
                }
                else
                {
                    minNum = mid + 1;
                }
            }

            // When the minNum exceeds the length of input array, return last index.
            if (minNum >= normalizedProbabilities.Length)
            {
                return normalizedProbabilities.Length - 1;
            }
            return minNum;
        }
    }
}
